from collections import Counter
from dataclasses import dataclass
from typing import Dict, List, Set
import json
import os

from helm.benchmark.scenarios.scenario import (
    CORRECT_TAG,
    ALL_SPLITS,
    Instance,
    Input,
    Output,
    Reference,
    Scenario,
)
from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.general import ensure_file_downloaded


@dataclass(frozen=True)
class HEIMHumanEvalReference(Reference):
    """A `Reference` that additionally records how many human annotators chose this answer."""

    # The number of human annotators who gave this reference or answer.
    num_human_answered: int = 0


class HEIMHumanEvalScenario(Scenario):
    """
    In [Holistic Evaluation of Text-To-Image Models (HEIM)](https://arxiv.org/abs/2311.04287), we evaluated 26
    state-of-the-art text-to-image models across 12 different aspects (e.g., toxicity mitigation, unbiasedness,
    originality, etc.). We used human annotators through AWS Mechanical Turk to evaluate the models for some of
    these aspects (see image below).
    This scenario contains the AI-generated images and human annotations for the following question types:

    1. Alignment
    2. Aesthetics
    3. Clear subject
    4. Originality
    5. Photorealism

    Citations:
    - HEIM: https://arxiv.org/abs/2311.04287
    - MS COCO: https://arxiv.org/abs/1405.0312
    """

    DATASET_DOWNLOAD_URL: str = (
        "https://worksheets.codalab.org/rest/bundles/0x502d646c366c4f1d8c4a2ccf163b958f/contents/blob/"
    )
    # Question types for which human annotations were collected (one annotations file per type and split).
    VALID_QUESTION_TYPES: Set[str] = {"alignment", "aesthetics", "clear_subject", "originality", "photorealism"}

    name = "heim_human_eval"
    description = (
        "Images generated by text-to-image models and human annotations for HEIM "
        "([paper](https://arxiv.org/abs/2311.04287))."
    )
    tags = ["vision-language", "visual question answering", "image evaluation"]

    def __init__(self, question_type: str):
        """
        :param question_type: Which human-evaluation question to build instances for;
                              must be one of `VALID_QUESTION_TYPES`.
        :raises ValueError: If `question_type` is not a valid question type.
        """
        super().__init__()
        # Raise (rather than assert) so the validation survives running under `python -O`.
        if question_type not in self.VALID_QUESTION_TYPES:
            raise ValueError(f"Invalid question type: {question_type}")
        self._question_type: str = question_type

    def get_instances(self, output_path: str) -> List[Instance]:
        """
        Download the dataset bundle and build one multiple-choice `Instance` per annotated image,
        across all splits.
        """
        # Download and unpack the dataset
        output_path = os.path.join(output_path, "dataset")
        ensure_file_downloaded(
            source_url=self.DATASET_DOWNLOAD_URL, target_path=output_path, unpack=True, unpack_type="untar"
        )

        # Load the multiple-choice question (question text + answer choices) for this question type
        with open(os.path.join(output_path, "questions.json"), encoding="utf-8") as questions_file:
            question_info: Dict = json.load(questions_file)[self._question_type]

        instances: List[Instance] = []
        for split in ALL_SPLITS:
            annotations_split_path: str = os.path.join(output_path, f"{self._question_type}_{split}.jsonl")
            with open(annotations_split_path, encoding="utf-8") as f:
                # Each line is a JSON object describing one image and its human annotations.
                # Iterate the file object directly instead of readlines() to avoid
                # materializing the whole file in memory.
                for line in f:
                    image_annotation: Dict = json.loads(line)
                    instances.append(self._build_instance(output_path, question_info, image_annotation, split))

        return instances

    def _build_instance(self, output_path: str, question_info: Dict, image_annotation: Dict, split: str) -> Instance:
        """
        Build a single multiple-choice `Instance` from one image annotation.

        :param output_path: Root directory of the unpacked dataset; image paths are relative to it.
        :param question_info: The question text and answer choices for this question type.
        :param image_annotation: One JSONL record with "image_path", "human_annotations"
                                 and optionally "prompt".
        :param split: The dataset split this instance belongs to.
        :raises FileNotFoundError: If the referenced image file is missing from the bundle.
        """
        image_path: str = os.path.join(output_path, image_annotation["image_path"])
        if not os.path.exists(image_path):
            # Raise (rather than assert) so a corrupt download is caught even under `python -O`.
            raise FileNotFoundError(f"Image {image_path} does not exist")

        # Count the human answers; the most common answer(s) -- the modes -- are marked correct
        human_answers_to_counts = Counter(str(answer) for answer in image_annotation["human_annotations"])
        max_count: int = max(human_answers_to_counts.values())
        modes: Set[str] = {value for value, count in human_answers_to_counts.items() if count == max_count}

        # Content shown to the model: the image, optionally the generation prompt, then the question
        content: List[MediaObject] = [MediaObject(location=image_path, content_type="image/png")]
        if "prompt" in image_annotation:
            # Include the prompt in the content if it exists
            prompt: str = image_annotation["prompt"]
            content.append(MediaObject(text=f"Description: {prompt}", content_type="text/plain"))
        content.append(MediaObject(text=question_info["question"], content_type="text/plain"))

        references: List[Reference] = [
            HEIMHumanEvalReference(
                Output(text=answer),
                # The mode is the most common human answer and the reference we mark as correct
                tags=[CORRECT_TAG] if value in modes else [],
                # Counter returns 0 for choices no annotator picked
                num_human_answered=human_answers_to_counts[value],
            )
            for value, answer in question_info["choices"].items()
        ]
        return Instance(
            Input(multimedia_content=MultimediaObject(content)),
            references=references,
            split=split,
        )
